% Per-annotator comparison of a classifier restricted to the "traditional"
% feature subset against one that uses every feature. Fills accuracy_tn and
% accuracy_roc (row 1: traditional subset, row 2: all features), which the
% plotting sections further down consume.
clear
clc
% metrics_all and choices_all are not defined anywhere in this script, so
% they are expected to be provided by this MAT-file.
load('hemisphere_dataset_summary.mat')

% Entries of choices_all to analyse; each one yields one column of results.
vals = [1,2,3,4,7,9,10,11,13,14,15,18];
num_annotators = numel(vals);
accuracy_tn = zeros(2,num_annotators);
accuracy_roc = zeros(2,num_annotators);
% Feature columns that form the "traditional" subset (identical for every
% annotator, so built once outside the loop).
traditional_ind = [1, 2, 35, 37, 42, 43, 45, 46, 51:66];

for i=1:num_annotators
    % Annotators 1-4, 5-8 and 9 onwards use metrics_all{1}, {2} and {3}.
    if i<5
        metrics = metrics_all{1};
    elseif i<9
        metrics = metrics_all{2};
    else
        metrics = metrics_all{3};
    end

    % Transpose so that rows index cells and columns index features.
    features = metrics';
    % Infinite entries are zeroed before z-scoring so they cannot turn the
    % column mean/std into Inf/NaN ...
    infs = isinf(features);
    features(infs) = 0;
    features = zscore(features, 0, 1);
    % ... and zeroed again afterwards so they sit at the column mean (0).
    features(infs) = 0;
    features_tr = features(:,traditional_ind);

    choices = choices_all{vals(i)};

    num_cells = size(features,1);
    % get_classifier_preds is defined outside this file; its second output
    % is averaged over dimension 1 to get one score per cell.
    % NOTE(review): presumably the rows are repeated runs/folds - confirm.
    [~,preds] = get_classifier_preds(features,choices);
    [~,preds_tr] = get_classifier_preds(features_tr,choices);
    preds = mean(preds,1);
    preds_tr = mean(preds_tr,1);

    % Area under the sampled ROC curve (trapezoidal rule).
    [fp_temp,tp_temp] = calculate_roc(choices,preds_tr);
    accuracy_roc(1,i) = trapz(fp_temp,tp_temp);
    [fp_temp,tp_temp] = calculate_roc(choices,preds);
    accuracy_roc(2,i) = trapz(fp_temp,tp_temp);

    % Threshold the scores at 0.5 and map them onto the -1/+1 label coding.
    evaluators = zeros(2,num_cells);
    evaluators(1,:) = (preds_tr>0.5)*2-1;
    evaluators(2,:) = (preds>0.5)*2-1;

    % Fraction of the cells with choices == -1 that each classifier also
    % labels -1. The mask does not depend on j, so it is built once.
    neg = (choices == -1);
    for j=1:2
        accuracy_tn(j,i) = mean(choices(neg) == evaluators(j,neg));
    end

end


%% Fig3C1: accuracy on choices == -1, traditional subset vs. all features
figure
groups = accuracy_tn;
n_points = size(groups,2);

% One bar per row of groups, showing the mean over its columns. hold is
% switched on only after the first bar so that bar sets up the fresh axes.
for k = 1:2
    bar(k, mean(groups(k,:)), 'EdgeColor', 'none', 'FaceAlpha', 0.7);
    hold on
end

% Individual values as black dots stacked over their bar.
for k = 1:2
    scatter(k*ones(1,n_points), groups(k,:), 30, 'black', 'filled')
end

% Join the two values of every column with a line (groups has two rows).
plot([1,2], groups, 'Color', 'black', 'linestyle', '-','LineWidth',1)

xticks([1 2])
xticklabels(["Traditional","All features"])
ylabel('Accuracy (TN)')
ylim([0,1])

% Paired Wilcoxon signed-rank test between the two rows; sigstar is defined
% outside this file and receives the bar positions and the p-value.
p = signrank(groups(1,:),groups(2,:));
sigstar([1,2],p)
set_common_gca_props

exportgraphics(gcf,'Fig3C1.eps','ContentType','vector')

%% Fig3C2: area under the ROC curve, traditional subset vs. all features
figure
groups = accuracy_roc;
n_points = size(groups,2);

% One bar per row of groups, showing the mean over its columns. hold is
% switched on only after the first bar so that bar sets up the fresh axes.
for k = 1:2
    bar(k, mean(groups(k,:)), 'EdgeColor', 'none', 'FaceAlpha', 0.7);
    hold on
end

% Individual values as black dots stacked over their bar.
for k = 1:2
    scatter(k*ones(1,n_points), groups(k,:), 30, 'black', 'filled')
end

% Join the two values of every column with a line (groups has two rows).
plot([1,2], groups, 'Color', 'black', 'linestyle', '-','LineWidth',1)

xticks([1 2])
xticklabels(["Traditional","All features"])
ylabel('ROC')
ylim([0,1])

% Paired Wilcoxon signed-rank test between the two rows; sigstar is defined
% outside this file and receives the bar positions and the p-value.
p = signrank(groups(1,:),groups(2,:));
sigstar([1,2],p)
set_common_gca_props

exportgraphics(gcf,'Fig3C2.eps','ContentType','vector')


%%

function [fp,tp] = calculate_roc(gt,preds,num_thr)
% CALCULATE_ROC  Sample an ROC curve by sweeping a score threshold from 1 to 0.
%
%   [fp,tp] = calculate_roc(gt,preds) uses 100 evenly spaced thresholds.
%   [fp,tp] = calculate_roc(gt,preds,num_thr) uses num_thr thresholds.
%
%   gt      - labels; entries equal to 1 count as positives and entries
%             equal to -1 as negatives.
%   preds   - scores used to index gt, so one score per entry of gt is
%             expected. Entry k is selected at threshold t when
%             preds(k) > t.
%   num_thr - optional number of thresholds (default 100).
%
%   fp, tp  - 1-by-num_thr false- and true-positive rates, ordered from
%             threshold 1 down to threshold 0.
%
%   The comparison is strict, so a score of exactly 0 is never selected,
%   not even at the last threshold. If gt has no entries equal to -1
%   (or to 1), the corresponding rate is 0/0 = NaN at every threshold.

    if nargin < 3
        num_thr = 100;
    end

    thr = linspace(1,0,num_thr);

    % Label counts do not depend on the threshold; compute them once.
    n_pos = sum(gt == 1);
    n_neg = sum(gt == -1);

    tp = zeros(1,num_thr);
    fp = zeros(1,num_thr);
    for i=1:num_thr
        % Labels of the entries whose score exceeds the current threshold.
        chosen = gt(preds > thr(i));
        tp(i) = sum(chosen == 1) / n_pos;
        fp(i) = sum(chosen == -1) / n_neg;
    end

end


function set_common_gca_props
% SET_COMMON_GCA_PROPS  Apply the shared axes styling to the current axes.
%   Sets no axes background colour, 1.5-pt axes lines, tick length 0.03
%   and a 1:3:1 plot-box aspect ratio on the axes returned by gca.
current_axes = gca;
set(current_axes, ...
    'Color', 'none', ...
    'LineWidth', 1.5, ...
    'TickLength', [0.03, 0.03]);
pbaspect(current_axes, [1 3 1])
end